{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Optimization with an Image Classification Example\n",
    "1. [Introduction](#Introduction)\n",
    "2. [Prerequisites and Preprocessing](#Prerequisites-and-Preprocessing)\n",
    "3. [Training the model](#Training-the-model)\n",
    "4. [Inference with the vanilla model](#Inference-with-the-vanilla-model)\n",
    "5. [Inference with the optimized model](#Inference-with-the-optimized-model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "Welcome to our model optimization example for image classification.\n",
    "\n",
    "In this demo, we will use the Amazon SageMaker image classification algorithm to train on the [Caltech-256 dataset](http://www.vision.caltech.edu/Image_Datasets/Caltech256/).\n",
    "\n",
    "To get started, we need to set up the environment with a few prerequisite steps for permissions, configurations, and so on."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites and Preprocessing\n",
    "\n",
    "### Permissions and environment variables\n",
    "\n",
    "Here we set up the linkage and authentication to AWS services. There are three parts to this:\n",
    "\n",
    "* The role used to give learning and hosting access to your data. This is obtained automatically from the role used to start the notebook.\n",
    "* The S3 bucket that you want to use for training and model data.\n",
    "* The Amazon SageMaker image classification Docker image, which need not be changed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "role = get_execution_role()\n",
    "sess = sagemaker.Session()\n",
    "bucket=sess.default_bucket()\n",
    "prefix = 'ic'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
    "\n",
    "training_image = get_image_uri(sess.boto_region_name, 'image-classification', repo_version=\"latest\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data preparation\n",
    "Download the data and transfer it to S3 for use in training. In this demo, we use the [Caltech-256](http://www.vision.caltech.edu/Image_Datasets/Caltech256/) dataset, which contains 30608 images of 256 objects. For the training and validation data, we follow the splitting scheme in this MXNet [example](https://github.com/apache/incubator-mxnet/blob/master/example/image-classification/data/caltech256.sh). In particular, it randomly selects 60 images per class for training and uses the remaining data for validation. The algorithm takes a `RecordIO` file as input. The user can also provide the image files as input, which will be converted into `RecordIO` format using MXNet's [im2rec](https://mxnet.incubator.apache.org/how_to/recordio.html?highlight=im2rec) tool. It takes around 50 seconds to convert the entire Caltech-256 dataset (~1.2GB) on a p2.xlarge instance. However, for this demo, we will use the `RecordIO` format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "import boto3\n",
    "\n",
    "def download(url):\n",
    "    # Download the file into the working directory unless it is already present\n",
    "    filename = url.split(\"/\")[-1]\n",
    "    if not os.path.exists(filename):\n",
    "        urllib.request.urlretrieve(url, filename)\n",
    "\n",
    "        \n",
    "def upload_to_s3(channel, file):\n",
    "    # Upload a local file to the bucket under the key <channel>/<file>\n",
    "    s3 = boto3.resource('s3')\n",
    "    key = channel + '/' + file\n",
    "    # Open the file in a with-block so the handle is closed after the upload\n",
    "    with open(file, \"rb\") as data:\n",
    "        s3.Bucket(bucket).put_object(Key=key, Body=data)\n",
    "\n",
    "\n",
    "# caltech-256\n",
    "download('http://data.mxnet.io/data/caltech-256/caltech-256-60-train.rec')\n",
    "download('http://data.mxnet.io/data/caltech-256/caltech-256-60-val.rec')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two channels: train and validation\n",
    "s3train = 's3://{}/{}/train/'.format(bucket, prefix)\n",
    "s3validation = 's3://{}/{}/validation/'.format(bucket, prefix)\n",
    "\n",
    "# upload the RecordIO (.rec) files to the train and validation channels\n",
    "!aws s3 cp caltech-256-60-train.rec $s3train --quiet\n",
    "!aws s3 cp caltech-256-60-val.rec $s3validation --quiet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we have the data available in the correct format for training, the next step is to train the model using the data. After setting the training parameters, we kick off training and poll for status until training is completed."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the model\n",
    "\n",
    "Now that we are done with all the setup that is needed, we are ready to train our image classifier. To begin, let us create a ``sagemaker.estimator.Estimator`` object. This estimator will launch the training job.\n",
    "\n",
    "### Training parameters\n",
    "There are two kinds of parameters that need to be set for training. The first kind is the parameters for the training job. These include:\n",
    "\n",
    "* **Training instance count**: This is the number of instances on which to run the training. When the number of instances is greater than one, the image classification algorithm runs in distributed mode.\n",
    "* **Training instance type**: This indicates the type of machine on which to run the training. Typically, we use GPU instances for this kind of training.\n",
    "* **Output path**: This is the S3 folder in which the training output is stored.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_output_location = 's3://{}/{}/output'.format(bucket, prefix)\n",
    "ic = sagemaker.estimator.Estimator(training_image,\n",
    "                                   role,\n",
    "                                   train_instance_count=1,\n",
    "                                   train_instance_type='ml.p3.8xlarge',\n",
    "                                   train_volume_size=50,\n",
    "                                   train_max_run=360000,\n",
    "                                   input_mode='File',\n",
    "                                   output_path=s3_output_location,\n",
    "                                   sagemaker_session=sess)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Apart from the above set of parameters, there are hyperparameters that are specific to the algorithm. These are:\n",
    "\n",
    "* **num_layers**: The number of layers (depth) for the network. We use 18 in this sample, but other values such as 50 or 152 can be used.\n",
    "* **image_shape**: The input image dimensions, 'num_channels, height, width', for the network. It should be no larger than the actual image size. The number of channels should be the same as in the actual image.\n",
    "* **num_classes**: This is the number of output classes for the new dataset. ImageNet was trained with 1000 output classes, but the number of output classes can be changed for fine-tuning. For Caltech-256, we use 257 because it has 256 object categories + 1 clutter class.\n",
    "* **num_training_samples**: This is the total number of training samples. It is set to 15420 (257 classes x 60 images per class) for the Caltech-256 dataset with the current split.\n",
    "* **mini_batch_size**: The number of training samples used for each mini batch. In distributed training, the number of training samples used per batch will be N * mini_batch_size, where N is the number of hosts on which training is run.\n",
    "* **epochs**: Number of training epochs.\n",
    "* **learning_rate**: Learning rate for training.\n",
    "* **top_k**: Report the top-k accuracy during training.\n",
    "* **precision_dtype**: Training datatype precision (default: float32). If set to 'float16', the training will be done in mixed-precision mode and will be faster than in float32 mode.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ic.set_hyperparameters(num_layers=18,\n",
    "                       image_shape=\"3,224,224\",\n",
    "                       num_classes=257,\n",
    "                       num_training_samples=15420,\n",
    "                       mini_batch_size=128,\n",
    "                       epochs=5,\n",
    "                       learning_rate=0.01,\n",
    "                       top_k=2,\n",
    "                       precision_dtype='float32')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Input data specification\n",
    "Set the data type and channels used for training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = sagemaker.session.s3_input(s3train, distribution='FullyReplicated', \n",
    "                        content_type='application/x-recordio', s3_data_type='S3Prefix')\n",
    "validation_data = sagemaker.session.s3_input(s3validation, distribution='FullyReplicated', \n",
    "                             content_type='application/x-recordio', s3_data_type='S3Prefix')\n",
    "\n",
    "data_channels = {'train': train_data, 'validation': validation_data}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Start the training\n",
    "Start training by calling the estimator's `fit` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "ic.fit(inputs=data_channels, logs=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference with the vanilla model\n",
    "\n",
    "***\n",
    "\n",
    "Now we will test the trained model without any specific optimization for the hardware."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ic_classifier = ic.deploy(initial_instance_count = 1,\n",
    "                          instance_type = 'ml.c5.4xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download test image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget -O /tmp/test.jpg http://www.vision.caltech.edu/Image_Datasets/Caltech256/images/080.frog/080_0001.jpg\n",
    "file_name = '/tmp/test.jpg'\n",
    "# test image\n",
    "from IPython.display import Image\n",
    "Image(file_name)  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Evaluation\n",
    "\n",
    "Run the image through the network for inference. The network outputs class probabilities, and typically one selects the class with the maximum probability as the final class output.\n",
    "\n",
    "**Note:** The output class detected by the network may not be accurate in this example. To limit the time taken and cost of training, we have trained the model for only 5 epochs. If the network is trained for more epochs (say 20), then the output class will be more accurate."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import numpy as np\n",
    "\n",
    "with open(file_name, 'rb') as f:\n",
    "    payload = f.read()\n",
    "    payload = bytearray(payload)\n",
    "    \n",
    "ic_classifier.content_type = 'application/x-image'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Time the prediction with the vanilla model.**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "result = json.loads(ic_classifier.predict(payload))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the result contains the probabilities for all classes\n",
    "# find the class with the maximum probability and print its label and probability\n",
    "index = np.argmax(result)\n",
    "object_categories = ['ak47', 'american-flag', 'backpack', 'baseball-bat', 'baseball-glove', 'basketball-hoop', 'bat', 'bathtub', 'bear', 'beer-mug', 'billiards', 'binoculars', 'birdbath', 'blimp', 'bonsai-101', 'boom-box', 'bowling-ball', 'bowling-pin', 'boxing-glove', 'brain-101', 'breadmaker', 'buddha-101', 'bulldozer', 'butterfly', 'cactus', 'cake', 'calculator', 'camel', 'cannon', 'canoe', 'car-tire', 'cartman', 'cd', 'centipede', 'cereal-box', 'chandelier-101', 'chess-board', 'chimp', 'chopsticks', 'cockroach', 'coffee-mug', 'coffin', 'coin', 'comet', 'computer-keyboard', 'computer-monitor', 'computer-mouse', 'conch', 'cormorant', 'covered-wagon', 'cowboy-hat', 'crab-101', 'desk-globe', 'diamond-ring', 'dice', 'dog', 'dolphin-101', 'doorknob', 'drinking-straw', 'duck', 'dumb-bell', 'eiffel-tower', 'electric-guitar-101', 'elephant-101', 'elk', 'ewer-101', 'eyeglasses', 'fern', 'fighter-jet', 'fire-extinguisher', 'fire-hydrant', 'fire-truck', 'fireworks', 'flashlight', 'floppy-disk', 'football-helmet', 'french-horn', 'fried-egg', 'frisbee', 'frog', 'frying-pan', 'galaxy', 'gas-pump', 'giraffe', 'goat', 'golden-gate-bridge', 'goldfish', 'golf-ball', 'goose', 'gorilla', 'grand-piano-101', 'grapes', 'grasshopper', 'guitar-pick', 'hamburger', 'hammock', 'harmonica', 'harp', 'harpsichord', 'hawksbill-101', 'head-phones', 'helicopter-101', 'hibiscus', 'homer-simpson', 'horse', 'horseshoe-crab', 'hot-air-balloon', 'hot-dog', 'hot-tub', 'hourglass', 'house-fly', 'human-skeleton', 'hummingbird', 'ibis-101', 'ice-cream-cone', 'iguana', 'ipod', 'iris', 'jesus-christ', 'joy-stick', 'kangaroo-101', 'kayak', 'ketch-101', 'killer-whale', 'knife', 'ladder', 'laptop-101', 'lathe', 'leopards-101', 'license-plate', 'lightbulb', 'light-house', 'lightning', 'llama-101', 'mailbox', 'mandolin', 'mars', 'mattress', 'megaphone', 'menorah-101', 'microscope', 'microwave', 'minaret', 'minotaur', 'motorbikes-101', 'mountain-bike', 'mushroom', 'mussels', 'necktie', 'octopus', 'ostrich',\n",
    "                     'owl', 'palm-pilot', 'palm-tree', 'paperclip', 'paper-shredder', 'pci-card', 'penguin', 'people', 'pez-dispenser', 'photocopier', 'picnic-table', 'playing-card', 'porcupine', 'pram', 'praying-mantis', 'pyramid', 'raccoon', 'radio-telescope', 'rainbow', 'refrigerator', 'revolver-101', 'rifle', 'rotary-phone', 'roulette-wheel', 'saddle', 'saturn', 'school-bus', 'scorpion-101', 'screwdriver', 'segway', 'self-propelled-lawn-mower', 'sextant', 'sheet-music', 'skateboard', 'skunk', 'skyscraper', 'smokestack', 'snail', 'snake', 'sneaker', 'snowmobile', 'soccer-ball', 'socks', 'soda-can', 'spaghetti', 'speed-boat', 'spider', 'spoon', 'stained-glass', 'starfish-101', 'steering-wheel', 'stirrups', 'sunflower-101', 'superman', 'sushi', 'swan', 'swiss-army-knife', 'sword', 'syringe', 'tambourine', 'teapot', 'teddy-bear', 'teepee', 'telephone-box', 'tennis-ball', 'tennis-court', 'tennis-racket', 'theodolite', 'toaster', 'tomato', 'tombstone', 'top-hat', 'touring-bike', 'tower-pisa', 'traffic-light', 'treadmill', 'triceratops', 'tricycle', 'trilobite-101', 'tripod', 't-shirt', 'tuning-fork', 'tweezer', 'umbrella-101', 'unicorn', 'vcr', 'video-projector', 'washing-machine', 'watch-101', 'waterfall', 'watermelon', 'welding-mask', 'wheelbarrow', 'windmill', 'wine-bottle', 'xylophone', 'yarmulke', 'yo-yo', 'zebra', 'airplanes-101', 'car-side-101', 'faces-easy-101', 'greyhound', 'tennis-shoes', 'toad', 'clutter']\n",
    "print(\"Result: label - \" + object_categories[index] + \", probability - \" + str(result[index]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Clean up\n",
    "\n",
    "\n",
    "When we're done with the endpoint, we can just delete it and the backing instances will be released. Run the following cell to delete the endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ic_classifier.delete_endpoint()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference with the optimized model\n",
    "\n",
    "***\n",
    "\n",
    "Now we will test the same trained model after optimizing it for the hardware it is deployed on."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimize the model specifically for the architecture\n",
    "Here we compile the same trained model specifically for the architecture we're deploying on, so that its predictions can be compared with those of the vanilla model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Store the compiled model in the parent of the training output location (drop the last path component)\n",
    "output_path = '/'.join(ic.output_path.split('/')[:-1])\n",
    "optimized_ic = ic.compile_model(target_instance_family='ml_c5',\n",
    "                                input_shape={'data':[1, 3, 224, 224]},  # Batch size 1, 3 channels, 224x224 images.\n",
    "                                output_path=output_path,\n",
    "                                framework='mxnet', framework_version='1.2.1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# There is a known issue where the SageMaker SDK locates the incorrect Docker image URI for Image Classification.\n",
    "# For now, we set the image URI manually.\n",
    "optimized_ic.image = get_image_uri(sess.boto_region_name, 'image-classification-neo', repo_version=\"latest\")\n",
    "# There is a known issue where the SageMaker SDK does not set the name. In the meantime, we set the name manually.\n",
    "optimized_ic.name = 'deployed-image-classification'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deploy the optimized model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "optimized_ic_classifier = optimized_ic.deploy(initial_instance_count = 1,\n",
    "                                              instance_type = 'ml.c5.4xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Make predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimized_ic_classifier.content_type = 'application/x-image'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Time the prediction with the optimized model.** Compare this prediction time with the one measured for the vanilla model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "result = json.loads(optimized_ic_classifier.predict(payload))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the result contains the probabilities for all classes\n",
    "# find the class with the maximum probability and print its label and probability\n",
    "index = np.argmax(result)\n",
    "print(\"Result: label - \" + object_categories[index] + \", probability - \" + str(result[index]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Clean up\n",
    "\n",
    "\n",
    "When we're done with the endpoint, we can just delete it and the backing instances will be released. Run the following cell to delete the endpoint."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimized_ic_classifier.delete_endpoint()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "notice": "Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
